import sys
import os
import datetime
import json
import argparse
import numpy as np, pandas as pd
import matplotlib.pyplot as plt
import sklearn.metrics

from config import PATH

def group_shares(df,dummycol,groupcol):
    """Return per-group sum, count and share of the indicator column ``dummycol``.

    The result is indexed by the values of ``groupcol`` and has three columns:
    ``dummycol`` (sum of the indicator), ``total`` (number of non-null
    indicator values) and ``share`` (sum divided by total).
    """
    per_group = df.groupby(groupcol)[dummycol]
    hits = per_group.sum()
    sizes = per_group.count()
    return pd.DataFrame({dummycol: hits, 'total': sizes, 'share': hits / sizes})


def compute_group_shares(df,varx,ipc):
    """Compute, for every code in column ``ipc``, the share of each indicator in ``varx``.

    When ``ipc == "ipc4"`` only rows of applications (``appln_nr``) that carry
    at least one code starting with G05 or G06 are considered.

    Returns a DataFrame indexed by the ``ipc`` code with, for each variable,
    its sum and a ``share_<var>`` column, plus one ``total`` column (taken
    from the first variable).  Returns ``None`` when ``varx`` is empty.

    The input frame is not modified.
    """
    if ipc == "ipc4":
        # Build the G05/G06 filter from temporary Series instead of writing
        # helper columns into the caller's DataFrame.
        ipc3 = df["ipc4"].str.slice(stop=3)
        is_g05g06 = ipc3.isin(["G05", "G06"]).astype(int)
        has_g05g06 = is_g05g06.groupby(df["appln_nr"]).transform("max") == 1
        df = df.loc[has_g05g06]

    merged = None
    for var in varx:
        shares = group_shares(df,var,ipc)
        shares["share_{}".format(var)] = shares["share"]
        del shares["share"]
        if merged is None:
            merged = shares
        else:
            # 'total' is already present from the first variable
            del shares["total"]
            merged = merged.merge(shares,left_index=True,right_index=True,how="left")

    return merged

class IPC4Pair(object):
    """Tracks which of the two codes of an IPC4 pair have been flagged."""

    def __init__(self,ipc1,ipc2):
        # code -> "flagged" marker; the flags are flipped directly by the
        # owner of the pair.  If ipc1 == ipc2 the dict holds a single key.
        self.ipc = {ipc1:False,ipc2:False}

    def matched(self):
        """Return True once every code of the pair has been flagged."""
        # all() also works for a degenerate pair with two identical codes,
        # where indexing the second value would raise IndexError.
        return all(self.ipc.values())

class IPC4Pairs(object):
    """Collection of IPC4 pairs with a code -> pairs lookup.

    Codes are flagged one at a time via ``flagCode``; ``reset`` clears all
    flags set since the previous reset.
    """

    def __init__(self,pairs):
        self._pairs = []
        self._map = {}      # code -> list of pairs containing that code
        self._codes = []    # codes flagged since the last reset()
        for first, second in pairs:
            tracker = IPC4Pair(first,second)
            self._pairs.append(tracker)
            for code in (first, second):
                self._map_ipc(code,tracker)

        self._flaggedPairs = []

    def _map_ipc(self,ipc,pair):
        """Register ``pair`` under the code ``ipc``."""
        self._map.setdefault(ipc, []).append(pair)

    def flagCode(self,code):
        """Flag ``code`` in its pairs; return True as soon as one pair is complete.

        Codes that belong to no pair are ignored and yield False.
        """
        trackers = self._map.get(code)
        if trackers is None:
            return False

        self._codes.append(code)
        for tracker in trackers:
            tracker.ipc[code] = True
            if tracker.matched():
                return True
        return False

    def reset(self):
        """Unflag every code flagged since the last reset."""
        for code in self._codes:
            for tracker in self._map[code]:
                tracker.ipc[code] = False
        self._codes.clear()

class IPCCassificationSpecification(object):
    """Plain container for the three parts of a classification.

    Attributes are assigned as given: ``ipc6_set``, ``ipc4_set`` and
    ``ipc4_pairs``.
    """

    def __init__(self,ipc6_set,ipc4_set,ipc4_pairs):
        self.ipc6_set, self.ipc4_set, self.ipc4_pairs = ipc6_set, ipc4_set, ipc4_pairs

class SequentialIPCClassificationState(object):
    """Streaming classifier for one specification.

    Rows ``(pat_num, code)`` are fed one at a time through ``process``; a new
    patent is detected when ``pat_num`` differs from the previous row, so the
    rows of one patent are expected to be contiguous.  A patent is recorded
    once any of the following holds for its codes:

    * a code is contained in ``spec.ipc6_set``;
    * a code's first four characters are in ``spec.ipc4_set`` and a code's
      first three characters are G05 or G06;
    * both codes of one of ``spec.ipc4_pairs`` occurred (4-character prefixes).
    """

    def __init__(self,spec):
        self._spec = spec
        self._G05G06_set = {"G05", "G06"}
        self._patents = []
        self._cur_pat = None
        self._i4p = IPC4Pairs(spec.ipc4_pairs)
        self._reset()

    def _reset(self):
        """Clear all per-patent flags before starting a new patent."""
        self._i4p.reset()
        self._haspair = False
        self._hasipc6 = self._hasipc4 = self._hasG05G06 = False
        self._done = False

    def _check(self,code):
        """Update the per-patent flags with one more code."""
        self._hasipc6 = self._hasipc6 or code in self._spec.ipc6_set
        if self._hasipc6:
            return

        self._hasG05G06 = self._hasG05G06 or code[:3] in self._G05G06_set

        prefix4 = code[:4]
        self._hasipc4 = self._hasipc4 or prefix4 in self._spec.ipc4_set

        # Only consult the pairs while the ipc4 + G05/G06 rule is not yet met.
        if not (self._hasipc4 and self._hasG05G06):
            self._haspair = self._i4p.flagCode(prefix4)

    def _is_classified(self):
        """True when any of the three classification rules is satisfied."""
        return self._hasipc6 or (self._hasipc4 and self._hasG05G06) or self._haspair

    def _flush(self):
        """Record the current patent if it has been classified."""
        if self._done:
            self._patents.append(self._cur_pat)

    def process(self,pat_num,code):
        """Consume one (patent, code) row."""
        if pat_num != self._cur_pat:
            # A new patent starts: store the verdict on the previous one.
            self._flush()
            self._cur_pat = pat_num
            self._reset()

        if self._done:
            # Already classified; remaining codes of this patent are skipped.
            return

        self._check(code)
        if self._is_classified():
            self._done = True

    def finish(self):
        """Flush the last patent and clear the state."""
        self._flush()
        self._cur_pat = None
        self._reset()

    def patents(self):
        """Return the list of patents recorded so far."""
        return self._patents

def sequentialClassify(codes,specs):
    """Run every specification in ``specs`` over the rows of ``codes``.

    ``codes`` is read positionally: column 0 is the patent identifier and
    column 1 the IPC code.  ``specs`` maps a name to a specification; the
    result maps the same names to the list of patents each one classified.
    A progress line is written to stdout while iterating.
    """
    states = {name: SequentialIPCClassificationState(specs[name]) for name in specs}

    total_rows = codes.shape[0]
    started = datetime.datetime.now()
    eta = ""
    for row in range(total_rows):
        # Refresh the ETA estimate only occasionally.
        if row % 100000 == 1:
            frac = float(row) / total_rows
            elapsed = datetime.datetime.now() - started
            eta = "(ETA: {})".format(elapsed*(1-frac)/frac)
        if row % 1000 == 0:
            frac = float(row) / total_rows
            sys.stdout.write("Classifying... status: {:>12}/{}: {:.2f}%  {}   \r".format(row,total_rows,round(frac*100,2),eta))
        pat = codes.iat[row,0]
        code = codes.iat[row,1]
        for state in states.values():
            state.process(pat,code)

    # Flush the last patent of every state.
    for state in states.values():
        state.finish()

    return {name: state.patents() for name, state in states.items()}

def aggregate_ipc6(df,N=100):
    """Expects a ("appln_nr","ipc6",...) data frame and aggregates small ipc6 classes (less than N patents)
    to a new class AAAAXX, where AAAA is the IPC4 class"""

    features = df.copy()
    del features["ipc6"]
    features.drop_duplicates('appln_nr',inplace=True)

    totals = df.groupby("ipc6").appln_nr.agg('count')
    print("Creating sets...")
    s = df.groupby("ipc6").appln_nr.agg(set)
    d = {}
    for x in s.index:
        d[x] = s[x]

    ipc6x = {}

    print("Iterating...")
    for ipc6 in totals.index:
        if totals[ipc6] < N:
            code = "{}XX".format(ipc6[:4])
            if not code in d:
                d[code] = set()
            d[code].update(d[ipc6])
            del d[ipc6]
            if not code in ipc6x:
                ipc6x[code] = set()
            ipc6x[code].add(ipc6)
    print("Creating df...")

    rows = []
    for ipc6 in d:
        for appln_nr in d[ipc6]:
            row = [appln_nr,ipc6]
            rows.append(row)

    newdf = pd.DataFrame(rows)
    newdf.columns = ("appln_nr","ipc6")
    print("Almost done... merging in features...")
    newdf = newdf.merge(features,on="appln_nr")
    print("Done.")
    for code in ipc6x:
        ipc6x[code] = list(ipc6x[code])

    return (newdf,ipc6x)

def merge_techfields(df,ipc2tf,ipc,ipc_type=None):
    """Attach technology-field information to the IPC codes in ``df[ipc]``.

    ``ipc2tf`` must provide the columns ``ipc_maingroup_symbol``,
    ``techn_field_nr``, ``techn_field`` and ``techn_sector``.

    ``ipc_type`` selects the matching strategy and defaults to ``ipc`` (which
    then has to be "ipc4" or "ipc6"):

    * "ipc4": every 4-character prefix of ``ipc_maingroup_symbol`` is mapped
      to its most frequent ``techn_field_nr``; the result is left-joined.
    * "ipc6": codes are first matched on the full ``ipc_maingroup_symbol``;
      rows without a match fall back to the "ipc4" strategy using the first
      four characters of the code.

    Neither ``df`` nor ``ipc2tf`` is modified.  Raises ``Exception`` for any
    other ``ipc_type``.
    """
    if not ipc_type:
        assert(ipc in ["ipc4","ipc6"])
        ipc_type = ipc

    if ipc_type == "ipc4":
        # Mapping of techn_field_nr to name and sector of the field
        tf = ipc2tf[["techn_field_nr","techn_field","techn_sector"]]
        tf = tf.drop_duplicates().set_index("techn_field_nr")

        # Most frequent ipc4 -> techn_field_nr mapping.  Grouping by a
        # derived Series avoids adding an "ipc4" column to the caller's table.
        ipc4_codes = ipc2tf.ipc_maingroup_symbol.str[:4].rename("ipc4")
        ipc4totf = ipc2tf.groupby(ipc4_codes).agg({'techn_field_nr':lambda x:x.value_counts().idxmax()})
        mapping = ipc4totf.merge(tf,left_on="techn_field_nr",right_index=True,how="left")
        return df.merge(mapping,left_on=ipc,right_index=True,how="left")

    if ipc_type == "ipc6":
        # 1. If we have a mapping for the full code, apply it
        m1 = df.merge(ipc2tf,left_on=ipc,right_on="ipc_maingroup_symbol",how="left")
        has_field = m1.techn_field_nr.notnull()
        df1 = m1.loc[has_field]

        # 2. For all the others, use the IPC4 prefix for merging.  The rows
        #    are selected by code rather than by the index of the merged
        #    frame, which need not line up with the index of df.
        unmatched = df.loc[~df[ipc].isin(m1.loc[has_field, ipc])].copy()
        unmatched["ipc4"] = unmatched[ipc].str[:4]
        df2 = merge_techfields(unmatched,ipc2tf,"ipc4")
        del df2["ipc4"]
        return pd.concat([df1,df2],sort=False)

    raise Exception("ipc_type='{}' not implemented!".format(ipc_type))

def merge_techfields_combo(df,ipc2tf):
    """Attach technology fields to both codes of an IPC4 pair frame.

    The field columns of ``ipc1`` get the suffix ``_1`` and those of ``ipc2``
    the suffix ``_2``.  The returned frame starts with ``ipc1`` and carries
    the previous index as an extra column produced by ``reset_index``.
    """
    field_columns = ('techn_field_nr', 'techn_field', 'techn_sector')
    for side in ("1", "2"):
        df = merge_techfields(df,ipc2tf,"ipc" + side,ipc_type="ipc4")
        df = df.rename(columns={name: "{}_{}".format(name, side) for name in field_columns})

    # quick way to move ipc1 back in front of ipc2
    df = df.reset_index()
    df = df.set_index('ipc1')
    return df.reset_index()


def restrict_ipc_codes(df,no_exceptions=False):
    """Restrict an ipc data frame to the relevant technical fields and to large classes.

    A code is kept if its ``techn_field`` is one of the machinery-related
    fields below, or (unless ``no_exceptions``) it lies in B42C/B07C or is
    G05B19/B62D65.  Frames without a ``techn_field`` column are treated as
    IPC4 pair frames (``ipc1``, ``ipc2``, ``techn_field_1``,
    ``techn_field_2``).  Finally only rows with ``total >= 100`` remain.
    """
    allowed_fields = ["handling","machine tools","textile and paper machines","other special machines"]
    banned_ipc3 = ["F41","F42"]
    extra_ipc4 = [] if no_exceptions else ["B42C","B07C"]
    extra_ipc6 = [] if no_exceptions else ["G05B19", "B62D65"]

    if "techn_field" in df.columns:
        # ipc6, ipc4 and ipc6XX data: a single code column
        ipc = next((col for col in ("ipc6", "ipc4") if col in df.columns), None)
        if ipc is None:
            raise ValueError("Cannot find ipc columns in data frame")

        # remove F41 and F42 codes
        df = df.loc[~df[ipc].str[:3].isin(banned_ipc3)]
        codes = df[ipc]
        keep = (
            df.techn_field.str.lower().isin(allowed_fields)
            | codes.str[:4].isin(extra_ipc4)
            | codes.isin(extra_ipc6)
        )
        df = df.loc[keep]
    else:
        # ipc4 pairs: either of the two fields needs to be in the list

        # A pair is dropped only when BOTH codes are F41/F42.
        # NOTE(review): for single codes any F41/F42 code is removed; confirm
        # that keeping pairs with one such code is intended.
        both_banned = df.ipc1.str[:3].isin(banned_ipc3) & df.ipc2.str[:3].isin(banned_ipc3)
        df = df.loc[~both_banned]

        field_ok = df.techn_field_1.str.lower().isin(allowed_fields)
        field_ok = field_ok | df.techn_field_2.str.lower().isin(allowed_fields)
        extra_ok = df.ipc1.isin(extra_ipc4) | df.ipc2.isin(extra_ipc4)
        # neither code may belong to section Y
        no_section_y = (df.ipc1.str[:1] != "Y") & (df.ipc2.str[:1] != "Y")
        df = df.loc[(field_ok | extra_ok) & no_section_y]

    # only look at large enough codes
    return df.loc[df["total"] >= 100]

def placebo_restrict_ipc_codes(df,kind,no_exceptions=False):
    """Take an ipc data frame and restrict it to placebo fields.

    ``kind`` is one of:

    * "automation" -- delegates to ``restrict_ipc_codes`` (``no_exceptions``
      is only used here);
    * "pharma" -- keeps the field "pharmaceuticals";
    * "chemistry" -- keeps "organic fine chemistry" and
      "macromolecular chemistry, polymers".

    For pair frames (no ``techn_field`` column) either of the two fields has
    to match.  Only rows with ``total >= 100`` are returned.

    Raises ``ValueError`` for an unknown ``kind`` or when a single-code frame
    has neither an ``ipc6`` nor an ``ipc4`` column.
    """
    if kind == "automation":
        return restrict_ipc_codes(df,no_exceptions=no_exceptions)

    placebo_fields = {
        "pharma": ["pharmaceuticals"],
        "chemistry": ["organic fine chemistry","macromolecular chemistry, polymers"],
    }
    if kind not in placebo_fields:
        # Previously an unknown kind surfaced as an UnboundLocalError below.
        raise ValueError("Unknown placebo kind '{}'; expected 'automation', 'pharma' or 'chemistry'".format(kind))
    restricted_fields = placebo_fields[kind]

    # keep codes only if techn_field in an appropriate field
    if "techn_field" in df.columns:
        # For ipc6, ipc4, and ipc6XX data
        if "ipc6" not in df.columns and "ipc4" not in df.columns:
            raise ValueError("Cannot find ipc columns in data frame")
        appr_field = df.techn_field.str.lower().isin(restricted_fields)
    else:
        # for ipc4 pairs, either of the two fields needs to be in the list
        appr_field = df.techn_field_1.str.lower().isin(restricted_fields)
        appr_field = appr_field | df.techn_field_2.str.lower().isin(restricted_fields)
    df = df.loc[appr_field]

    # only look at large enough codes
    return df.loc[df.total >= 100]

def generate_stata_commands(df,filename=None,ipc6x_mapping=None):
    """Build (and print) a Stata script that selects the applications matching ``df``.

    ``df`` is either a single-code frame (column ``ipc6`` or ``ipc4``) or a
    pair frame (columns ``ipc1`` and ``ipc2``).  Codes ending in "XX" are
    expanded through ``ipc6x_mapping`` (aggregate code -> original codes).

    ``filename`` is the target of the final ``save`` command; when it is
    ``None`` no ``save`` command is emitted.

    Returns the script as one newline-joined string.  Raises ``ValueError``
    when ``df`` has none of the expected columns or when an "XX" code is
    encountered without a mapping.
    """
    cmds = ["use automation/datasets/cipc_codes.dta, clear"]
    cmds.append("ren cipc6 ipc_code")

    if "ipc6" in df.columns:
        var = "ipc6"
    elif "ipc4" in df.columns:
        var = "ipc4"
    elif "ipc1" in df.columns:
        var = "ipc4_pairs"
    else:
        raise ValueError("Unknown format of data frame. Expected ipc4/ipc6 or ipc1,ipc2 columns")

    if var == "ipc4":
        # ipc4 codes only count for applications that also have a G05/G06 code
        cmds.append('gen g05g06 = strmatch(ipc_code,"G05*") | strmatch(ipc_code,"G06*")')
        cmds.append('bys appln_id : egen has_g05g06 = max(g05g06)')
        cmds.append('keep if has_g05g06')
        cmds.append('drop g05g06 has_g05g06')

    cmds.append('gen {} = 0'.format(var))
    if var in ["ipc6","ipc4"]:
        for ipc in df[var]:
            cmds.append('* {} {} '.format(var,ipc))
            if ipc.endswith("XX"):
                if ipc6x_mapping is None:
                    raise ValueError("ipc6x_mapping is required to expand aggregated code '{}'".format(ipc))
                for ipc6x in ipc6x_mapping[ipc]:
                    cmds.append('replace {} = {} | strmatch(ipc_code,"{}*")'.format(var,var,ipc6x))
            else:
                cmds.append('replace {} = {} | strmatch(ipc_code,"{}*")'.format(var,var,ipc))
    else:
        # Iterate the two columns directly; looking rows up by index label
        # breaks when the index contains duplicates.
        for ipc1, ipc2 in zip(df["ipc1"], df["ipc2"]):
            cmds.append('* {} {}-{} '.format(var,ipc1,ipc2))
            cmds.append('gen ipc1 = strmatch(ipc_code,"{}*")'.format(ipc1))
            cmds.append('gen ipc2 = strmatch(ipc_code,"{}*")'.format(ipc2))
            cmds.append('bys appln_id : egen ipc1m = max(ipc1)')
            cmds.append('bys appln_id : egen ipc2m = max(ipc2)')
            cmds.append('replace {} = {} | (ipc1m & ipc2m)'.format(var,var))
            cmds.append('drop ipc1* ipc2*')

    cmds.append('keep if {}'.format(var))
    cmds.append('keep appln_id')
    cmds.append('duplicates drop')
    cmds.append('compress')
    if filename is not None:
        cmds.append('save {}, replace'.format(filename))
    cmds = "\n".join(cmds)
    print(cmds)
    return cmds


class SetEncoder(json.JSONEncoder):
    """JSON encoder that writes ``set`` objects as JSON arrays."""

    def default(self,obj):
        # Anything that is not a set is left to the base class, which raises
        # TypeError for unsupported objects.
        if not isinstance(obj,set):
            return super().default(obj)
        return list(obj)

def _decoder(obj):
    """Return ``obj`` converted to a set if it is a list, otherwise unchanged."""
    # NOTE(review): presumably the counterpart of SetEncoder; it is not
    # referenced anywhere in the visible code.
    return set(obj) if isinstance(obj,list) else obj


def compute_ipc4_combinations(df,N=100,features=None):
    """Compute all IPC4 code pairs shared by at least ``N`` applications.

    ``df`` has the columns ('appln_nr','ipc4',dummyfeature1,dummyfeature2,...).
    ``features`` lists the dummy columns to aggregate and defaults to every
    column except ``appln_nr`` and ``ipc4``.

    Returns a frame with one row per unordered pair and the columns
    ``ipc1``, ``ipc2``, one sum per feature (over the applications carrying
    both codes, using each application's first feature value), ``total``
    (number of such applications) and ``share_<feature>``.  Pairs are listed
    in sorted code order, with ``ipc1 < ipc2``.  When no pair qualifies an
    empty frame with the same columns is returned.
    """
    from itertools import combinations

    if features is None:
        features = [col for col in df.columns if col not in ("appln_nr","ipc4")]
    features = list(features)

    print("Computing maps...")
    # Map appln_nr -> feature value for each feature
    by_appln = df.groupby("appln_nr")
    fmap = {var: by_appln[var].agg("first") for var in features}

    # Totals per ipc4 code.  A pair can only reach N shared applications if
    # both codes have at least N rows, so ">=" matches the pair threshold.
    count_col = features[0] if features else "appln_nr"
    totals = df.groupby("ipc4")[count_col].agg('count')
    ipc_codes = sorted(totals.index[totals >= N])

    # Sets of applications for each candidate code, built in a single pass
    print("Computing IPC sets...")
    appln_sets = df.groupby("ipc4").appln_nr.agg(set)
    ipc2apn = {ipc: appln_sets[ipc] for ipc in ipc_codes}

    print("Iterating over IPC combinations...")
    tot = len(ipc_codes) * (len(ipc_codes) - 1) // 2
    rows = []
    # combinations() visits every unordered pair exactly once
    for i, (ipc1, ipc2) in enumerate(combinations(ipc_codes, 2), start=1):
        intersection = ipc2apn[ipc1] & ipc2apn[ipc2]
        if len(intersection) < N:
            continue

        members = list(intersection)
        row = [ipc1, ipc2]
        row.extend(sum(fmap[var].loc[members]) for var in features)
        row.append(len(intersection))

        print("{}/{}: Found {}-{}".format(i,tot,ipc1,ipc2))
        rows.append(row)

    cols = ["ipc1","ipc2"] + features + ["total"]
    # Column names go to the constructor so that an empty result is valid too.
    newdf = pd.DataFrame(rows, columns=cols)

    for var in features:
        newdf["share_{}".format(var)] = newdf[var]/newdf["total"]

    return newdf


def plot_histogram(series,q):
    """Draw a 20-bin histogram of ``series`` with dashed lines at the values in ``q``.

    ``q`` is a Series whose index holds probabilities (e.g. 0.9) and whose
    values are the positions of the vertical lines; each line is labelled
    ``p<percent>`` just above the axes.  Draws on the current pyplot figure.
    """
    plt.hist(series,bins=20)
    # x in data coordinates, y in axes coordinates, so the labels sit at a
    # fixed height above the plot regardless of the bar heights
    tr = plt.gca().get_xaxis_transform()
    for p, value in q.items():
        plt.axvline(value,c="black",ls='--')
        # round before converting: p*100 is a float product and plain int()
        # truncation would turn e.g. 0.29 into "p28"
        plt.text(value-0.015,1.025,f"p{int(round(p*100))}",transform=tr)


# Command-line entry point.  --task selects one step of the pipeline and
# --path overrides the data directory imported from config (PATH).
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Classify patents according to IPC codes')
    parser.add_argument("--path", type=str, help="Path argument")
    parser.add_argument("--task", type=str, help="Task to perform")
    args = parser.parse_args()

    PATH = args.path if args.path else PATH
    task = args.task

    # --- group_shares: appln_<ipc>[XX].csv -> <ipc>[XX].csv -------------------
    # Computes totals and indicator shares per IPC code for ipc6XX, ipc4 and ipc6.
    if task == "group_shares":
        for XX in ["XX",""]:
            for ipc in ["ipc4","ipc6"]:
                # there is no aggregated ("XX") variant of the ipc4 data
                if ipc == "ipc4" and XX=="XX":
                    continue
                I = "{}/appln_{}{}.csv".format(PATH,ipc,XX)
                O = "{}/{}{}.csv".format(PATH,ipc,XX)

                df = pd.read_csv(I,dtype={'appln_nr':str})

                # add a automationnolabor column (any - labor)
                automation_components_wo_labor = ["automat","robot","CNC","CADCAM","threedee","flexman","PLC"]
                df["automationnolabor"] = df[automation_components_wo_labor].max(axis=1)

                # every column except the application number and the code
                # column is treated as an indicator (lazy filter object)
                varx = filter(lambda c : c not in ["appln_nr",ipc],list(df.columns))

                dfm = compute_group_shares(df,varx,ipc)
                dfm.to_csv(O)

    # --- aggregate_ipc6: appln_ipc6.csv -> appln_ipc6XX.csv + mappings --------
    elif task == "aggregate_ipc6":
        I = "{}/appln_ipc6.csv".format(PATH)
        O = "{}/appln_ipc6XX.csv".format(PATH)
        O2 = "{}/ipc6x_mapping.json".format(PATH)
        df = pd.read_csv(I,dtype={'appln_nr':str})
        (dfa,mapping) = aggregate_ipc6(df,N=100)
        dfa.to_csv(O,index=False)
        print("Saving ipc6x mapping...")
        with open(O2,"w") as f:
            f.write(json.dumps(mapping))
        # the same mapping again as a flat two-column CSV (IPC6 -> IPC6XX)
        O3 = "{}/ipc6XX_mapping.csv".format(PATH)
        data = []
        for ipc6XX in mapping:
            for ipc6 in mapping[ipc6XX]:
                data.append([ipc6,ipc6XX])
        dfm = pd.DataFrame(data,columns=["IPC6","IPC6XX"])
        dfm.to_csv(O3,index=False)

    # --- merge_techfields: <ipc>[XX].csv -> <ipc>[XX]_tf.csv ------------------
    elif task == "merge_techfields":
        # read relative to the working directory, not to PATH
        tf = pd.read_csv("patstat/ipc_techn_fields.csv")
        tf["ipc_maingroup_symbol"] = tf["ipc_maingroup_symbol"].str.replace(" ","")

        # NOTE(review): this frame is overwritten inside the loop below before
        # it is used, so the read looks redundant.
        df = pd.read_csv("{}/ipc6.csv".format(PATH))

        for XX in ["XX",""]:
            for ipc in ["ipc4","ipc6"]:
                if ipc == "ipc4" and XX=="XX":
                    continue
                I = "{}/{}{}.csv".format(PATH,ipc,XX)
                O = "{}/{}{}_tf.csv".format(PATH,ipc,XX)
                df = pd.read_csv(I)
                df = merge_techfields(df,tf,ipc)
                print("Saving {}...".format(O))
                df.to_csv(O,index=False)
        I = "{}/ipc4_pairs.csv".format(PATH)
        O = "{}/ipc4_pairs_tf.csv".format(PATH)

        df = pd.read_csv(I)
        df = merge_techfields_combo(df,tf)
        # drop the column created by reset_index() in merge_techfields_combo
        del df["index"]
        print("Saving {}...".format(O))

        df.to_csv(O,index=False)

    # --- visualize_classification: summary statistics and histograms ----------
    elif task == "visualize_classification":
        key = "anyclassification"
        cl = "share_{}".format(key)
        d = "{}/img".format(PATH)
        plt.close('all')
        # Quantile markers (p90/p95).  They are computed once, from the first
        # data set processed (restricted ipc6XX), and reused in every plot.
        pu = pw = None
        for tfr in ["_restricted",""]:
            if not os.path.isdir(d):
                os.mkdir(d)
            for XX in ["XX",""]:
                for ipc in ["ipc6","ipc4"]:
                    if ipc == "ipc4" and XX=="XX":
                        continue
                    I = "{}/{}{}_tf.csv".format(PATH,ipc,XX)
                    df = pd.read_csv(I)
                    if tfr == "_restricted":
                        df = restrict_ipc_codes(df)
                    df.rename(columns={cl:'share'},inplace=True)
                    # "weighted" statistics repeat each code's share once per
                    # patent (df.total)
                    print("{}{} :: Summary statistics".format(ipc,XX))
                    print("Quantiles [unweighted]:")
                    print(df.share.quantile([0.25,0.5,0.6,0.75,0.8,0.85,0.9,0.95,0.99]))
                    print("Quantiles [weighted]:")
                    print(df.share.repeat(df.total).quantile([0.25,0.5,0.6,0.75,0.8,0.85,0.9,0.95,0.99]))
                    print("Mean [unweighted]:")
                    print(df.share.mean())
                    print("Mean [weighted]:")
                    print(df.share.repeat(df.total).mean())
                    print("Std dev. [unweighted]:")
                    print(df.share.std())
                    print("Std dev. [weighted]:")
                    print(df.share.repeat(df.total).std())

                    if pu is None:
                        pu = df.share.quantile([0.90,0.95])
                        pw = df.share.repeat(df.total).quantile([0.90,0.95])
                    plot_histogram(df.share,pu)
                    plt.savefig("{}/{}{}_unweighted{}.png".format(d,ipc,XX,tfr))
                    plt.clf()
                    plot_histogram(df.share.repeat(df.total),pw)
                    plt.savefig("{}/{}{}_weighted{}.png".format(d,ipc,XX,tfr))
                    plt.clf()

            # same statistics and plots for the ipc4 pairs
            I = "{}/ipc4_pairs_tf.csv".format(PATH)
            df = pd.read_csv(I)
            if tfr == "_restricted":
                df = restrict_ipc_codes(df)
            df.rename(columns={cl:'share'},inplace=True)
            # the format() arguments are unused: the string has no placeholders
            print("ipc4 pairs :: Summary statistics".format(ipc,XX))
            print("Quantiles [unweighted]:")
            print(df.share.quantile([0.25,0.5,0.6,0.75,0.8,0.85,0.9,0.95,0.99]))
            print("Quantiles [weighted]:")
            print(df.share.repeat(df.total).quantile([0.25,0.5,0.6,0.75,0.8,0.85,0.9,0.95,0.99]))
            print("Mean [unweighted]:")
            print(df.share.mean())
            print("Mean [weighted]:")
            print(df.share.repeat(df.total).mean())
            print("Std dev. [unweighted]:")
            print(df.share.std())
            print("Std dev. [weighted]:")
            print(df.share.repeat(df.total).std())
            plot_histogram(df.share,pu)
            plt.savefig("{}/ipc4_pairs_unweighted{}.png".format(d,tfr))
            plt.clf()
            plot_histogram(df.share.repeat(df.total),pw)
            plt.savefig("{}/ipc4_pairs_weighted{}.png".format(d,tfr))
            plt.clf()


    # --- classify: build all specifications, then classify the patents --------
    elif task == "classify":
        # Order matters: the share threshold is derived from the first entry
        # (ipc6XX) and then applied to ipc4 and the ipc4 pairs as well.
        varx = ["ipc6XX","ipc4","ipc4_pairs"]

        # In the following, we are going to define a set of classifications.
        # A classification is defined by:
        # - a set of ipc6 codes
        # - a set of ipc4 codes (interacted with G05&G06)
        # - a set of ipc4 combinations
        # Classifications are stored in the map d:
        d = {}

        # The first set of classifications are automation80, automation90 and automation95:
        Q = [80,90,95]
        for q in Q:
            name = "automation{}".format(q)
            d[name] = IPCCassificationSpecification(None,None,None)
            thresh = None
            for var in varx:
                I = "{}/{}_tf.csv".format(PATH,var)

                df = pd.read_csv(I)
                df = restrict_ipc_codes(df)
                df.rename(columns={'share_anyclassification':'share'},inplace=True)
                # all restricted candidate codes, before thresholding
                df.to_csv("{}/codes_auto{}_{}_n.csv".format(PATH,q,var), index=False)
                # NOTE(review): "not thresh" is also true for a threshold of
                # exactly 0.0, in which case it is recomputed from the next
                # file; "thresh is None" may be what is meant.  The same
                # pattern recurs in the loops below.
                if not thresh:
                    thresh = df.share.quantile(q/100.0)
                    print("Automation q={}%, thresh={}".format(q,thresh))
                df = df.loc[df.share>=thresh]
                # codes that pass the threshold
                df.to_csv("{}/classified_auto{}_{}_n.csv".format(PATH,q,var), index=False)
                if var=="ipc6XX":
                    d[name].ipc6_set = set(df["ipc6"])
                elif var=="ipc4":
                    d[name].ipc4_set = set(df["ipc4"])
                else:
                    d[name].ipc4_pairs = list(zip(df["ipc1"],df["ipc2"]))

        # Do a no-exceptions version of above:
        Q = [95]
        for q in Q:
            name = "automationX{}".format(q)
            d[name] = IPCCassificationSpecification(None,None,None)
            thresh = None
            for var in varx:
                I = "{}/{}_tf.csv".format(PATH,var)

                df = pd.read_csv(I)
                df = restrict_ipc_codes(df,no_exceptions=True)
                df.rename(columns={'share_anyclassification':'share'},inplace=True)
                if not thresh:
                    thresh = df.share.quantile(q/100.0)
                    print("AutomationX q={}%, thresh={}".format(q,thresh))
                df = df.loc[df.share>=thresh]
                if var=="ipc6XX":
                    d[name].ipc6_set = set(df["ipc6"])
                elif var=="ipc4":
                    d[name].ipc4_set = set(df["ipc4"])
                else:
                    d[name].ipc4_pairs = list(zip(df["ipc1"],df["ipc2"]))

        # Here we look at "automation without labor" (i.e., {any}\{labor})
        Q = [80,90,95]
        for q in Q:
            name = "automationnolabor{}".format(q)
            d[name] = IPCCassificationSpecification(None,None,None)
            thresh = None
            for var in varx:
                I = "{}/{}_tf.csv".format(PATH,var)

                df = pd.read_csv(I)
                df = restrict_ipc_codes(df)
                df.rename(columns={'share_automationnolabor':'share'},inplace=True)
                if not thresh:
                    thresh = df.share.quantile(q/100.0)
                    print("Automationnolabor q={}%, thresh={}".format(q,thresh))
                df = df.loc[df.share>=thresh]
                if var=="ipc6XX":
                    d[name].ipc6_set = set(df["ipc6"])
                elif var=="ipc4":
                    d[name].ipc4_set = set(df["ipc4"])
                else:
                    d[name].ipc4_pairs = list(zip(df["ipc1"],df["ipc2"]))

        # Here we look at subcomponents (CNC, automat, robot, labor).
        # Note that the threshold is the share_anyclassification quantile of
        # the restricted ipc6XX codes, not a quantile of the subcomponent share.
        Q = [80,90,95]
        for q in Q:
            shares = ["CNC","automat","robot","labor"]
            I = "{}/ipc6XX_tf.csv".format(PATH)
            df = pd.read_csv(I)
            df = restrict_ipc_codes(df)
            thresh = df.share_anyclassification.quantile(q/100.0)
            print("Subcomponents; using: Automation q={}%, thresh={}".format(q,thresh))
            for share in shares:
                name = "{}{}".format(share,q)
                d[name] = IPCCassificationSpecification(None,None,None)
                for var in varx:
                    I = "{}/{}_tf.csv".format(PATH,var)
                    df = pd.read_csv(I)
                    df = restrict_ipc_codes(df)
                    df.rename(columns={'share_{}'.format(share):'share'},inplace=True)
                    df = df.loc[df.share>=thresh]
                    if var=="ipc6XX":
                        d[name].ipc6_set = set(df["ipc6"])
                    elif var=="ipc4":
                        d[name].ipc4_set = set(df["ipc4"])
                    else:
                        d[name].ipc4_pairs = list(zip(df["ipc1"],df["ipc2"]))

        # NOTE(review): tf is loaded here but not used in the rest of this branch.
        tf = pd.read_csv("patstat/ipc_techn_fields.csv")
        tf["ipc_maingroup_symbol"] = tf["ipc_maingroup_symbol"].str.replace(" ","")

        # Here we define placebos: codes are taken from the placebo field
        # (chemistry, pharma, or the automation fields themselves), while the
        # threshold is the quantile of the automation-restricted ipc6XX shares.
        Q = [60,90,95]
        for field in ["chemistry","pharma","automation"]:
            for q in Q:
                name = "p{}{}_complement".format(field[:4],q)
                d[name] = IPCCassificationSpecification(None,None,None)
                thresh = None
                for var in varx:
                    I = "{}/{}_tf.csv".format(PATH,var)
                    df = pd.read_csv(I)
                    dfauto = restrict_ipc_codes(df)
                    df = placebo_restrict_ipc_codes(df,field)
                    df.rename(columns={'share_anyclassification':'share'},inplace=True)
                    if not thresh:
                        thresh = dfauto.share_anyclassification.quantile(q/100.0)
                        print("Placebo {} q={}%, thresh={}".format(name[:5],q,thresh))
                    df = df.loc[df.share>=thresh]
                    if var=="ipc6XX":
                        d[name].ipc6_set = set(df["ipc6"])
                    elif var=="ipc4":
                        d[name].ipc4_set = set(df["ipc4"])
                    else:
                        d[name].ipc4_pairs = list(zip(df["ipc1"],df["ipc2"]))


        # Here we are figuring out how many patents are classified by ipc6, ipc4, or the pairs respectively.
        # Each of these specifications has exactly one non-empty component.
        Q = [90,95]
        for q in Q:
            thresh = None
            for var in varx:
                name = "automation{}_{}".format(q,var)
                d[name] = IPCCassificationSpecification(None,None,None)
                I = "{}/{}_tf.csv".format(PATH,var)

                df = pd.read_csv(I)
                df = restrict_ipc_codes(df)
                df.rename(columns={'share_anyclassification':'share'},inplace=True)
                if not thresh:
                    thresh = df.share.quantile(q/100.0)
                    print("Automation q={}%, thresh={}".format(q,thresh))
                    # also store the threshold in a small text file
                    txt_file_path = "{}/auto{}_{}.txt".format(PATH, q, var)
                    with open(txt_file_path, "w") as file:
                        file.write("Automation q={}%, thresh={}\n".format(q, thresh))
                df = df.loc[df.share>=thresh]
                # start from empty components, then fill only the one for var
                d[name].ipc6_set = set()
                d[name].ipc4_set = set()
                d[name].ipc4_pairs = set()
                if var=="ipc6XX":
                    d[name].ipc6_set = set(df["ipc6"])
                elif var=="ipc4":
                    d[name].ipc4_set = set(df["ipc4"])
                else:
                    d[name].ipc4_pairs = list(zip(df["ipc1"],df["ipc2"]))

        print("Targets:")
        for name in d:
            print(" - {}".format(name))

        print("{}: Loading patents...".format(datetime.datetime.now()))
        # read relative to the working directory, not to PATH
        cipc_codes = pd.read_csv("patstat/docdb_family_id_cipc_codes.csv")

        # Replace every code that was aggregated into an "XX" class by that
        # class; codes without a mapping entry are kept as they are.
        print("{}: Applying XX mapping...".format(datetime.datetime.now()))
        ipc6xx_mapping = pd.read_csv("{}/ipc6XX_mapping.csv".format(PATH))
        ipc6xx_mapping = ipc6xx_mapping.rename(columns={'IPC6':'cipc6','IPC6XX':'cipc6xx'})
        cipc6xx_codes = pd.merge(cipc_codes, ipc6xx_mapping, how = "left", on = "cipc6")
        cipc6xx_codes['cipc6'] = cipc6xx_codes['cipc6xx'].fillna(cipc6xx_codes['cipc6'])
        # sequentialClassify reads column 0 as the patent id and column 1 as the code.
        # NOTE(review): it detects a new patent by comparing consecutive rows,
        # so the file is presumably grouped by docdb_family_id -- confirm.
        cipc6xx_codes = cipc6xx_codes[["docdb_family_id", "cipc6"]]

        patent_lists = sequentialClassify(cipc6xx_codes,d)

        # one CSV of classified families per specification
        for name in patent_lists:
            # NOTE(review): the bare except hides every mkdir failure, not
            # only "directory already exists".
            try:
                os.mkdir("{}/patent_lists".format(PATH))
            except:
                pass

            O = "{}/patent_lists/{}.csv".format(PATH,name)
            patent_list = patent_lists[name]
            df = pd.DataFrame({'docdb_family_id':patent_list})
            df.to_csv(O,index=False)
    # --- ipc4_pairs: appln_ipc4.csv -> ipc4_pairs.csv -------------------------
    elif task == "ipc4_pairs":
        df = pd.read_csv("{}/appln_ipc4.csv".format(PATH),dtype={'appln_nr':str})
        # TODO: JF: added again the nonlabor column also here (because it is not in the features list);
        # could be done in a more elegant way
        automation_components_wo_labor = ["automat","robot","CNC","CADCAM","threedee","flexman","PLC"]
        df["automationnolabor"] = df[automation_components_wo_labor].max(axis=1)
        df2 = compute_ipc4_combinations(df)
        df2.to_csv("{}/ipc4_pairs.csv".format(PATH),index=False)

    else:
        print("Unknown task: {}".format(task))
